dynamic_inference update #60
Conversation
@tpapp can you please take a look? I have made some changes, but I am stuck on how to use the new transformations interface to let the user pass the transformations as an argument to the function.
I tried to clarify a bit, happy to help more.
test/dynamicHMC.jl
Outdated
```
@@ -16,8 +16,8 @@ t = collect(range(1,stop=10,length=10)) # observation times
sol = solve(prob1,Tsit5())
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)

bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], [bridge(ℝ, ℝ⁺, )])
transform = (a = asℝ₊)
```
I think this needs a `,` to make a (named) tuple, i.e. `(a = asℝ₊,)`.
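For context, the trailing comma is what distinguishes a one-element NamedTuple from a plain parenthesized assignment in Julia; a quick REPL check:

```julia
julia> x = (a = 1)    # parenthesized assignment: x is just 1
1

julia> x = (a = 1,)   # trailing comma makes a one-element NamedTuple
(a = 1,)
```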
src/dynamichmc_inference.jl
Outdated
```
@@ -33,7 +33,7 @@ function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, t, data, priors,
                              kwargs...)
    likelihood = sol -> sum( sum(logpdf.(Normal(0.0, σ), sol(t) .- data[:, i]))
                             for (i, t) in enumerate(t) )

    println(typeof(transformations))
```
Not sure why you are printing it, but that's a minor thing. Regarding the API: perhaps I am missing something, but I would make transformations part of the problem, i.e. in `DynamicHMCPosterior`.
Could you clarify how that would work? I also feel that adding the transformation the way I am doing it is not really going to work (I actually copied this straight from your linear regression example, as I understood it).
You could also make the transformation a slot in the composite type. But with this interface, it may not be needed.
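A minimal sketch of what that slot could look like; the field name and type parameters here are illustrative, not the PR's actual struct:

```julia
# Hypothetical layout: the posterior carries its own transformation.
struct DynamicHMCPosterior{A,P,L,Pr,T,K}
    alg::A
    problem::P
    likelihood::L
    priors::Pr
    transformation::T   # e.g. as((a = asℝ₊,)); assumed field, not in the PR
    kwargs::K
end
```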
Oh I see, that does sound good, but how it would be applied is not clear to me; I can imagine it with the `ContinuousTransformations` interface. I'll push what I think would work here, but I doubt it will be correct 😅
src/dynamichmc_inference.jl
Outdated
```
@@ -43,39 +43,35 @@ function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, likelihood, priors,
                              ϵ=0.001, initial=Float64[], num_samples=1000,
                              kwargs...)
    P = DynamicHMCPosterior(alg, prob, likelihood, priors, kwargs)
    println(typeof(transformations))
    prob_transform(P::DynamicHMCPosterior) = as((transformations))
```
No need for `as(...)` here; that's the API for making transformations. I would just make the user do that, and work with the transformation object from then on.
I felt this would give a cleaner interface, but it can be changed later on too, so I'll keep this in mind for sure.
What I am saying is that defining a function for the sole purpose of returning the transformations makes little sense. I would do something like

```julia
function dynamichmc_inference(prob::DiffEqBase.DEProblem, alg, likelihood, priors, transformations;
                              ϵ=0.001, initial=Float64[], num_samples=1000,
                              kwargs...)
    P = DynamicHMCPosterior(alg, prob, likelihood, priors, kwargs)
    PT = TransformedLogDensity(transformations, P)
    PTG = FluxGradientLogDensity(PT)
    chain, NUTS_tuned = NUTS_init_tune_mcmc(PTG, num_samples, ϵ=ϵ)
    posterior = transform.(Ref(PTG.transformation), get_position.(chain))
    return posterior, chain, NUTS_tuned
end
```

instead.
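A hypothetical call with that signature might then look like this, assuming `likelihood` and `priors` are constructed as elsewhere in this thread:

```julia
# the user builds the transformation with TransformVariables and passes it through
trans = as((a = asℝ₊,))
posterior, chain, NUTS_tuned = dynamichmc_inference(prob1, Tsit5(), likelihood,
                                                    [Normal(1.5, 1)], trans)
```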
Oh I see.
Also, in addition (because this didn't change so I can't comment on it): is

```julia
if !isfinite(ℓ) && (ℓ ≠ -Inf)
    ℓ = -Inf # protect against NaN etc, is it needed?
end
```

the right way to do it? The function needs to return a finite real or `-Inf`.
It can be required in ODE systems because, for certain parameters, the ODE may only have a well-defined solution for a finite time before it blows up. In these cases we want to associate an infinite cost with those parameters. In terms of the likelihood, we want to assign zero likelihood to having these as the solution, so -Inf in the log likelihood sounds correct.
test/dynamicHMC.jl
Outdated
```
@@ -16,7 +16,9 @@ t = collect(range(1,stop=10,length=10)) # observation times
sol = solve(prob1,Tsit5())
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
transform = (a = asℝ₊,)
function transform(p)
```
I guess I was not clear (sorry): I would just design the API so that the user can provide the transformation directly, e.g. as

```julia
bayesian_result = dynamichmc_inference(prob1, Tsit5(), t, data, [Normal(1.5, 1)], as((a = asℝ₊,)))
```

as transformations are already callable.
Apologies for the documentation of the related packages being so WIP, PRs are welcome.
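For illustration, a minimal sketch of how such a transformation behaves, assuming the TransformVariables API:

```julia
using TransformVariables

t = as((a = asℝ₊,))      # one positive parameter, named `a`
dimension(t)             # 1: NUTS samples in unconstrained ℝ¹
transform(t, [0.0])      # (a = 1.0,) since asℝ₊ maps x ↦ exp(x)
inverse(t, (a = 1.0,))   # [0.0], back to the unconstrained space
```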
For typing issues, see https://discourse.julialang.org/t/differentialequations-jl-flux-jl-or-knet-jl/8208/9?u=chrisrackauckas. We may want to run a different AD through the solver than the Flux Tracker types. I'm not sure what causes this, since the ReverseDiff.jl tracker types work just fine. That could potentially be the issue, but I'd test it in a quick script first before assuming so.
SciML/DiffEqBase.jl@5184e94 is what was necessary to make Tracked values work. Upped the DiffEqBase min version to 4.28.1 for it.
The typing issues are fixed, but for some reason the sampling is all in a weird part of parameter space. @tpapp is there a good way to test that the gradients are correct?
@ChrisRackauckas: you can compare the results from e.g. Flux and ForwardDiff, or to finite differences.
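For example, a generic way to check an AD gradient against central finite differences; the quadratic `f` below is just a stand-in for the actual log density:

```julia
using ForwardDiff

f(x) = sum(abs2, x)                  # stand-in target; substitute the log density
x0 = [1.0, 2.0]
g_ad = ForwardDiff.gradient(f, x0)

# central finite differences, one coordinate at a time
function fd_gradient(f, x; h = 1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x)); e[i] = h
        (f(x .+ e) - f(x .- e)) / (2h)
    end
end

maximum(abs.(g_ad .- fd_gradient(f, x0)))   # should be ≈ 0
```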
Also, consider …
I couldn't find where the derivative choice is defined. Is this in the docs somewhere? Or how are you seeding the Flux Tracker types to do the AD? We can test Flux's AD here: https://github.com/JuliaDiffEq/OrdinaryDiffEqExtendedTests.jl/blob/master/test/autodiff_events.jl or on this: https://github.com/JuliaDiffEq/DiffEqSensitivity.jl/blob/master/test/adjoint.jl#L31-L40 but I'm not familiar with how to seed Flux for that.
It's the choice of a wrapper, in particular the line in the code of this PR: …
Tests show that Flux's AD does indeed work on the diffeq solver: SciML/SciMLSensitivity.jl@b16489d. So I'm not sure why it's sampling the way it does.
Alright, the DiffEq side of this should be solid now. @tpapp are we not extracting the chain correctly? For some reason, …
Looks OK at first glance; let me look at it in detail (a bit busy now, may take a few days).
@tpapp I was trying to find out how to pass a larger number of steps for …
REQUIRE
Outdated
```
Distances
ApproxBayes
TransformVariables
LogDensityProblems
Flux
```
We can drop the Flux dep.
One way to avoid local minima might be to pass an initial value (I think Stan and Turing allow for that), which we could find by running an optimizer. Is that something we could pursue here?
A local optimizer will find a local optimum. Even then, it would just find the posterior around one optimum, which is less than ideal.
@ChrisRackauckas: no, the initial point just matters for adaptation of the stepsize. Ideally it should be irrelevant. When there is something funny going on, it adapts the stepsize badly and the tuner steps may not be sufficient for recovery. The algorithm still explores the whole posterior, just slower. @Vaibhavdixit02: since you are working with simulated data, the quick & dirty solution would be to just start at the known parameter value. If that solves the issue, we can explore adaptation separately.
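Following that suggestion, a sketch of starting the chain at the known value. This assumes the DynamicHMC 1.x API, where `NUTS_init_tune_mcmc` forwards a starting position via the `q` keyword; the keyword and the value `1.5` are assumptions, not taken from the PR:

```julia
# a = 1.5 is the value used to simulate the data in this thread's test setup
t  = as((a = asℝ₊,))
q₀ = inverse(t, (a = 1.5,))   # unconstrained coordinates for the sampler
chain, NUTS_tuned = NUTS_init_tune_mcmc(∇P, 1000; q = q₀)
```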
I have created a self-contained MWE for this problem here. I would suggest we work on this to solve the problem. NOTE: use … I ran into the following problems:
```
julia> chain, nuts = NUTS_init_tune_mcmc(∇P, 1000); q = inverse(t, (a = a₀,))
┌ Warning: Interrupted. Larger maxiters is needed.
└ @ DiffEqBase ~/.julia/packages/DiffEqBase/dxcbq/src/integrator_interface.jl:124
ERROR: BoundsError: attempt to access 1-element Array{Float64,1} at index [2]
Stacktrace:
 [1] getindex(::Array{Float64,1}, ::Int64) at ./array.jl:739
 [2] ode_interpolation(::Float64, ::Function, ::Nothing, ::Type, ::NamedTuple{(:a,),Tuple{ForwardDiff.Dual{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##1#2")){TransformedLogDensity{TransformVariables.TransformNamedTuple{(:a,),Tuple{TransformVariables.ShiftedExp{true,Float64}}},BayesianODEProblem{typeof(parameterized_lotkavolterra),Array{Float64,1},Tuple{Float64,Float64},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},typeof(prior_lotkavolterra),MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},Array{Array{Float64,1},1}}}},Float64},Float64,1}}}, ::Symbol) at /home/tamas/.julia/packages/OrdinaryDiffEq/S005Q/src/dense/generic_dense.jl:238
 [3] InterpolationData at /home/tamas/.julia/packages/OrdinaryDiffEq/S005Q/src/interp_func.jl:70 [inlined]
 [4] #call#6 at /home/tamas/.julia/packages/DiffEqBase/dxcbq/src/solutions/ode_solutions.jl:14 [inlined]
 [5] ODESolution at /home/tamas/.julia/packages/DiffEqBase/dxcbq/src/solutions/ode_solutions.jl:14 [inlined] (repeats 2 times)
```

Am I using the ODE framework in the wrong way?
```
ERROR: Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.
```

which is why I am adding the …
I can't run the MWE since `∇P = ADgradient(Val(:ForwardDiff), P)` is undefined.
Only turning off adaptivity, but that's really not an option in most cases. Decreasing the tolerance reduces the size of the jumps.
Not necessarily. It can be very easy in these kinds of models to choose parameters which don't have a convergent solution. This makes parameter estimation difficult since there are unknown nonlinear boundaries past which the solution is no longer well-behaved, hence the Inf handling. For example, with this problem there is a nonlinear boundary past which the solution is no longer cyclic and instead diverges to infinity or converges to zero. What parameter value did that happen for?
That shouldn't be a solution to it. Most likely it's the same problem as the second one: the solution exits early due to some divergence, and this logdensity implementation doesn't handle the possibility that the solution can halt early.
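A sketch of such a guard, checking both the return code and that the integration actually reached the last observation time (names follow the MWE above; the fields used are the standard DiffEq solution interface):

```julia
solution = solve(odeproblem, Tsit5())
# reject parameters for which the integration halted early, instead of
# interpolating past the point where the solution blew up
if solution.retcode != :Success || solution.t[end] < timepoints[end]
    return -Inf   # zero likelihood
end
```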
As I said, use …
I checked it out by adding a print of the parameters:

```julia
function (problem::BayesianODEProblem)(parameters)
    @unpack f, u, timespan, timepoints, logprior, noise, data = problem
    val, par = ForwardDiff.value(parameters.a), ForwardDiff.partials(parameters.a)
    @show val, par
    u_widened = typeof(parameters.a).(u) # NOTE: UGLY HACK
    odeproblem = ODEProblem(f, u_widened, timespan, parameters)
    solution = solve(odeproblem, Tsit5())
    loglikelihood = sum(logpdf(noise, d .- solution(timepoint))
                        for (d, timepoint) in zip(data, timepoints))
    out = loglikelihood + logprior(parameters)
    outval, outpar = ForwardDiff.value(out), ForwardDiff.partials(out)
    @show outval, outpar
    out
end
```

and got:

```
(val, par) = (0.5792652917385154, Partials(0.5792652917385154,))
(outval, outpar) = (-537.4348052283677, Partials(-1741.3026430313244,))
(val, par) = (0.0, Partials(0.0,))
(outval, outpar) = (-585.7773541564096, Partials(-0.0,))
(val, par) = (1.3739641428174652e-95, Partials(1.3739641428174652e-95,))
(outval, outpar) = (-585.7773541564096, Partials(9.831235249322356e-93,))
(val, par) = (1.1729953426768217e-24, Partials(1.1729953426768217e-24,))
(outval, outpar) = (-585.7773541564096, Partials(8.393227159893489e-22,))
(val, par) = (6.619182513371314e-7, Partials(6.619182513371314e-7,))
(outval, outpar) = (-585.7768805280341, Partials(0.0004736291058905471,))
(val, par) = (0.018536182489576378, Partials(0.018536182489576378,))
(outval, outpar) = (-571.9417595700329, Partials(14.406013840303185,))
(val, par) = (0.24237830445443795, Partials(0.24237830445443795,))
(outval, outpar) = (-559.3999863392875, Partials(-941.891584027233,))
(val, par) = (0.4633901899822291, Partials(0.4633901899822291,))
(outval, outpar) = (-407.6337648555079, Partials(337.58332127165414,))
(val, par) = (0.3534154992627069, Partials(0.3534154992627069,))
(outval, outpar) = (-477.6005447011211, Partials(-539.4992294059473,))
(val, par) = (0.41009317286490954, Partials(0.41009317286490954,))
(outval, outpar) = (-476.67711727739515, Partials(1512.7659178377664,))
(val, par) = (0.3819668269900791, Partials(0.3819668269900791,))
(outval, outpar) = (-535.9210932827468, Partials(-234.25225637410773,))
(val, par) = (0.39610879786373776, Partials(0.39610879786373776,))
(outval, outpar) = (-524.082043526718, Partials(949.9882432727143,))
(val, par) = (0.4031239610986336, Partials(0.4031239610986336,))
(outval, outpar) = (-502.6989544856027, Partials(1449.4138180147552,))
(val, par) = (0.399621711377083, Partials(0.399621711377083,))
(outval, outpar) = (-514.4384523580005, Partials(1235.9574520570047,))
(val, par) = (0.39786653629399304, Partials(0.39786653629399304,))
(outval, outpar) = (-519.5648671449208, Partials(1099.1941827044236,))
(val, par) = (0.39874445065687203, Partials(0.39874445065687203,))
(outval, outpar) = (-517.0733769183182, Partials(1169.527605742951,))
MCMC, adapting ϵ (75 steps)
(val, par) = (0.5792652917385154, Partials(0.5792652917385154,))
(outval, outpar) = (-537.4348052283677, Partials(-1741.3026430313244,))
(val, par) = (0.4060322707948085, Partials(0.4060322707948085,))
(outval, outpar) = (-491.9285388343783, Partials(1539.5350410718802,))
(val, par) = (0.4060322707948085, Partials(0.4060322707948085,))
(outval, outpar) = (-491.9285388343783, Partials(1539.5350410718802,))
(val, par) = (1.4203721487858283e28, Partials(1.4203721487858283e28,))
┌ Warning: Interrupted. Larger maxiters is needed.
└ @ DiffEqBase C:\Users\Chris\.julia\dev\DiffEqBase\src\integrator_interface.jl:124
```

The issues with divergence are because the parameters go bonkers after a few runs. Why would they diverge like that? The partials all seem fine in the likelihood.

```
(val, par) = (0.9600603566426819, Partials(0.9600603566426819,))
(outval, outpar) = (-496.2876428092132, Partials(-298.463691885259,))
(val, par) = (5.779863369367498e-66, Partials(5.779863369367498e-66,))
(outval, outpar) = (-585.7773541564096, Partials(4.135711749847459e-63,))
(val, par) = (3.31484364694372e-17, Partials(3.31484364694372e-17,))
(outval, outpar) = (-585.7773541564096, Partials(2.3718965213311913e-14,))
(val, par) = (6.144259381762237e-5, Partials(6.144259381762237e-5,))
(outval, outpar) = (-585.7333833632036, Partials(0.043977077322872024,))
(val, par) = (0.07846143128753674, Partials(0.07846143128753674,))
(outval, outpar) = (-520.1067611943706, Partials(73.16280660175568,))
(val, par) = (0.4906766915911269, Partials(0.4906766915911269,))
(outval, outpar) = (-374.7103811922104, Partials(695.6249021884687,))
(val, par) = (0.2268847673020799, Partials(0.2268847673020799,))
(outval, outpar) = (-506.0826684156832, Partials(-637.0283434483251,))
(val, par) = (0.345995349573582, Partials(0.345995349573582,))
(outval, outpar) = (-473.6180228672568, Partials(88.41409066765786,))
(val, par) = (0.41579122008851077, Partials(0.41579122008851077,))
(outval, outpar) = (-457.40737664079535, Partials(1270.5276558994913,))
(val, par) = (0.45271093917067545, Partials(0.45271093917067545,))
(outval, outpar) = (-414.0449055080945, Partials(234.14024480277817,))
(val, par) = (0.4341047680169557, Partials(0.4341047680169557,))
(outval, outpar) = (-425.1255897731216, Partials(376.1099374175511,))
(val, par) = (0.42490959321640964, Partials(0.42490959321640964,))
(outval, outpar) = (-436.0494617191355, Partials(721.4983953577138,))
(val, par) = (0.42949780277014654, Partials(0.42949780277014654,))
(outval, outpar) = (-429.6608366674312, Partials(516.8025724185654,))
(val, par) = (0.4272013253908379, Partials(0.4272013253908379,))
(outval, outpar) = (-432.5736280049785, Partials(608.1304496434184,))
(val, par) = (0.4260548626655953, Partials(0.4260548626655953,))
(outval, outpar) = (-434.2364519752774, Partials(661.3747972619409,))
(val, par) = (0.42548207834745694, Partials(0.42548207834745694,))
(outval, outpar) = (-435.12375647714777, Partials(690.4352950565778,))
(val, par) = (0.4257684331623101, Partials(0.4257684331623101,))
(outval, outpar) = (-434.6753444205647, Partials(675.6803777654867,))
(val, par) = (0.4259116385846774, Partials(0.4259116385846774,))
(outval, outpar) = (-434.4547147271462, Partials(668.4735785198479,))
(val, par) = (0.4259832482936653, Partials(0.4259832482936653,))
(outval, outpar) = (-434.3452883733769, Partials(664.910914739759,))
(val, par) = (0.42601905489686853, Partials(0.42601905489686853,))
(outval, outpar) = (-434.2907965452112, Partials(663.1395643321872,))
MCMC, adapting ϵ (75 steps)
(val, par) = (0.9600603566426819, Partials(0.9600603566426819,))
(outval, outpar) = (-496.2876428092132, Partials(-298.463691885259,))
(val, par) = (0.4886401173492016, Partials(0.4886401173492016,))
(outval, outpar) = (-377.65226199108514, Partials(715.971513298252,))
(val, par) = (0.45358898754030974, Partials(0.45358898754030974,))
(outval, outpar) = (-413.5820944112558, Partials(241.4547971895248,))
(val, par) = (0.6847192765742149, Partials(0.6847192765742149,))
(outval, outpar) = (-511.7987225061569, Partials(-79.689106121853,))
(val, par) = (0.4886401173492016, Partials(0.4886401173492016,))
(outval, outpar) = (-377.65226199108514, Partials(715.971513298252,))
(val, par) = (2.971617284436392e45, Partials(2.971617284436392e45,))
┌ Warning: Interrupted. Larger maxiters is needed.
└ @ DiffEqBase C:\Users\Chris\.julia\dev\DiffEqBase\src\integrator_interface.jl:124
```

It looks like it does this whenever it repeats a parameter value. Notice how above it does …
Thanks for the detailed analysis. Yes, it does look fishy; I am currently investigating this. That said, the likelihood should be defined for all valid parameter values. Large values are easy to reach in this case, as we use the exponential transformation to constrain the parameter:

```julia
julia> log(1.4203721487858283e28)
64.82330151772011
```

which is not that crazy. A quick & dirty fix is to constrain the parameter to a bounded interval, but it would be better to somehow detect the explosion and return `-Inf`.
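The quick & dirty version of that fix could be a bounded transformation in place of `asℝ₊`; the interval below is illustrative only:

```julia
using TransformVariables

# a ∈ (0, 10) instead of a ∈ (0, ∞), so exp can never reach explosive values
t_bounded = as((a = as(Real, 0.0, 10.0),))
```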
That's what we do in the PR, but it was left out of the MWE :). That's what

```julia
if any((s.retcode != :Success for s in sol))
    ℓ = -Inf
else
    ℓ = likelihood(sol)
end
```

does. The retcode lets you know whether the solution was successful, and this also handles the fact that the number of saved datapoints won't match the observations. It sounds like we should replace this with throwing …
That's probably the better solution. Any … A combination of these is likely required for any real usage.
I have been debugging this further, and I think the kinetic energy is somehow maladapted. At the moment I am not sure why. I will keep working on this and report back.
src/dynamichmc_inference.jl
Outdated
```diff
 sol = solve(prob, alg; kwargs...)
 if any((s.retcode != :Success for s in sol))
-    ℓ = -Inf
+    ℓ = LogDensityProblems.reject_logdensity()
```
@tpapp is this a correct use of `reject_logdensity()`?
Nope. Call it as a function; it will throw an exception that is caught by the wrappers and converted to `-Inf`. Or you can return `-Inf` directly, if you don't want to unwind the stack.
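So, per that comment, a corrected sketch would either call it for its exception or skip it entirely and return `-Inf`:

```julia
if any((s.retcode != :Success for s in sol))
    # throws; the LogDensityProblems wrappers catch this and turn it into -Inf
    LogDensityProblems.reject_logdensity()
end
ℓ = likelihood(sol)
```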
Alright, then I'll just return `-Inf`.
This is still failing the test, but nothing errors on usage. It seems to be an issue with the sampler, so I'm just going to merge since this still fixes usage in some sense.